{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "rl-baselines-zoo.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/master/rl_baselines_zoo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XJy9QoDC7XA7",
        "colab_type": "text"
      },
      "source": [
        "# RL Baselines Zoo: Training in Colab\n",
        "\n",
        "GitHub repo: [https://github.com/araffin/rl-baselines-zoo](https://github.com/araffin/rl-baselines-zoo)\n",
        "\n",
        "Stable-Baselines repo: [https://github.com/hill-a/stable-baselines](https://github.com/hill-a/stable-baselines)\n",
        "\n",
        "Medium article: [https://medium.com/@araffin/stable-baselines-a-fork-of-openai-baselines-df87c4b2fc82](https://medium.com/@araffin/stable-baselines-a-fork-of-openai-baselines-df87c4b2fc82)\n",
        "\n",
        "## Install Dependencies"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AXVDDlTn02M9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Stable Baselines only supports TensorFlow 1.x for now\n",
        "%tensorflow_version 1.x\n",
        "!apt-get update\n",
        "# -y answers the confirmation prompt, since the cell cannot read keyboard input\n",
        "!apt-get install -y swig cmake libopenmpi-dev zlib1g-dev ffmpeg freeglut3-dev xvfb\n",
        "!pip install stable-baselines[mpi] --upgrade\n",
        "!pip install pybullet\n",
        "!pip install box2d box2d-kengz pyyaml pytablewriter optuna scikit-optimize\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kDjF3qRg7oGH",
        "colab_type": "text"
      },
      "source": [
        "## Clone RL Baselines Zoo Repo"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SCjGikdT1DFy",
        "colab_type": "code",
        "outputId": "8cf923aa-c025-4ab0-cf2f-473e648c0c30",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 159
        }
      },
      "source": [
        "!git clone https://github.com/araffin/rl-baselines-zoo"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'rl-baselines-zoo'...\n",
            "remote: Enumerating objects: 102, done.\u001b[K\n",
            "remote: Counting objects: 100% (102/102), done.\u001b[K\n",
            "remote: Compressing objects: 100% (83/83), done.\u001b[K\n",
            "remote: Total 1506 (delta 34), reused 56 (delta 12), pack-reused 1404\u001b[K\n",
            "Receiving objects: 100% (1506/1506), 375.23 MiB | 15.12 MiB/s, done.\n",
            "Resolving deltas: 100% (866/866), done.\n",
            "Checking out files: 100% (301/301), done.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "REMQlh-ezyVt",
        "colab_type": "code",
        "outputId": "72803a67-d200-4c1f-ecfa-74de4e773804",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "%cd rl-baselines-zoo/"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/content/rl-baselines-zoo\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6gJ-pAbF7zRZ",
        "colab_type": "text"
      },
      "source": [
        "## Train an RL Agent\n",
        "\n",
        "The trained agent is saved in the `logs/` folder.\n",
        "\n",
        "Here we train A2C on the CartPole-v1 environment for 100,000 steps.\n",
        "\n",
        "To train on Pong (Atari) instead, pass `--env PongNoFrameskip-v4`.\n",
        "\n",
        "Note: to support a new environment, you need to add its hyperparameters to `hyperparams/algo.yml`, where `algo` is the algorithm name (for example, `hyperparams/a2c.yml`). You can edit the file from the side panel of Google Colab (see https://stackoverflow.com/questions/46986398/import-data-into-google-colaboratory)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9bIR_N7R11XI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!python train.py --algo a2c --env CartPole-v1 --n-timesteps 100000"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-fHBq73665yD",
        "colab_type": "text"
      },
      "source": [
        "### Evaluate the Trained Agent\n",
        "\n",
        "Remove the `--folder logs/` option to evaluate a pretrained agent instead of the one you just trained."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Bw8YuEgU6bT3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!python enjoy.py --algo a2c --env CartPole-v1 --no-render --n-timesteps 5000 --folder logs/"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w5Il2J0VHPLC",
        "colab_type": "text"
      },
      "source": [
        "### Tune Hyperparameters\n",
        "\n",
        "We use [Optuna](https://optuna.org/) to optimize the hyperparameters.\n",
        "\n",
        "The command below tunes the hyperparameters of PPO2 on MountainCar-v0 using a TPE sampler and a median pruner, with 2 parallel jobs,\n",
        "a budget of 1000 trials, and a maximum of 50,000 steps per trial."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "w2sC22eGHTH-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!python train.py --algo ppo2 --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 --sampler tpe --pruner median"
      ],
      "execution_count": 0,
      "outputs": []
    },
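    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The `--sampler tpe` and `--pruner median` options above correspond to Optuna's `TPESampler` and `MedianPruner`. As a rough, self-contained sketch (not the zoo's actual tuning code), the snippet below shows how such a sampler and pruner are typically wired into an Optuna study. The `objective` function, its toy score, and the `learning_rate` search range are hypothetical stand-ins for a real training run, and `suggest_float` assumes a recent Optuna release.\n",
        "\n",
        "```python\n",
        "import optuna\n",
        "from optuna.pruners import MedianPruner\n",
        "from optuna.samplers import TPESampler\n",
        "\n",
        "\n",
        "def objective(trial):\n",
        "    # Hypothetical search space: a single learning rate sampled on a log scale.\n",
        "    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)\n",
        "    score = 0.0\n",
        "    # Stand-in for a training loop: report a toy score at each step.\n",
        "    for step in range(10):\n",
        "        score = -abs(learning_rate - 1e-3) * (10 - step)\n",
        "        trial.report(score, step)\n",
        "        # Let the pruner stop unpromising trials early.\n",
        "        if trial.should_prune():\n",
        "            raise optuna.TrialPruned()\n",
        "    return score\n",
        "\n",
        "\n",
        "study = optuna.create_study(\n",
        "    direction='maximize',\n",
        "    sampler=TPESampler(seed=0),\n",
        "    pruner=MedianPruner(),\n",
        ")\n",
        "study.optimize(objective, n_trials=20, n_jobs=2)\n",
        "print(study.best_params)\n",
        "```\n",
        "\n",
        "With `n_jobs=2`, Optuna runs two trials concurrently in threads, which mirrors the `--n-jobs 2` option of the command above."
      ]
    },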
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xVm9QPNVwKXN",
        "colab_type": "text"
      },
      "source": [
        "### Record a Video"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MPyfQxD5z26J",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Set up display; otherwise rendering will fail\n",
        "import os\n",
        "os.system(\"Xvfb :1 -screen 0 1024x768x24 &\")\n",
        "os.environ['DISPLAY'] = ':1'"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HS1VBBaQ_emT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install pyglet==1.3.1  # pyglet v1.4.1 throws an error"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ip3AauLzwNGP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!python -m utils.record_video --algo a2c --env CartPole-v1 --exp-id 0 -f logs/ -n 1000"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qBuUfnzI8DN6",
        "colab_type": "text"
      },
      "source": [
        "### Display the Video"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZC3OTfpf8CXu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import base64\n",
        "from pathlib import Path\n",
        "\n",
        "from IPython import display as ipythondisplay\n",
        "\n",
        "\n",
        "def show_videos(video_path='', prefix=''):\n",
        "    \"\"\"\n",
        "    Taken from https://github.com/eleurent/highway-env\n",
        "\n",
        "    :param video_path: (str) Path to the folder containing videos\n",
        "    :param prefix: (str) Filter the videos, showing only the ones starting with this prefix\n",
        "    \"\"\"\n",
        "    html = []\n",
        "    for mp4 in Path(video_path).glob(\"{}*.mp4\".format(prefix)):\n",
        "        # Embed each video in the page as a base64 data URI\n",
        "        video_b64 = base64.b64encode(mp4.read_bytes())\n",
        "        html.append('''<video alt=\"{}\" autoplay\n",
        "                    loop controls style=\"height: 400px;\">\n",
        "                    <source src=\"data:video/mp4;base64,{}\" type=\"video/mp4\" />\n",
        "                </video>'''.format(mp4, video_b64.decode('ascii')))\n",
        "    ipythondisplay.display(ipythondisplay.HTML(data=\"<br>\".join(html)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
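    {
      "cell_type": "code",
      "metadata": {
        "id": "oKOjFuwK9HI0",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Assumes record_video wrote to its default output folder, logs/videos/\n",
        "show_videos(video_path='logs/videos/', prefix='a2c')"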
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RjdpP0HE8D2p",
        "colab_type": "text"
      },
      "source": [
        "### Continue Training\n",
        "\n",
        "Here, we continue training the model trained above."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zgMZQJJF6u1C",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!python train.py --algo a2c --env CartPole-v1 --n-timesteps 50000 -i logs/a2c/CartPole-v1_1/CartPole-v1.zip"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GSaoyiAE8cVj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!python enjoy.py --algo a2c --env CartPole-v1 --no-render --n-timesteps 1000 --folder logs/"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
